Mulgrad

计算逐元素乘法 (Mul) 操作的梯度。该算子是 Mul 算子的反向传播(backward pass)部分。梯度的计算遵循链式法则。

\[ \begin{align}\begin{aligned}\text{dx1}_i = \text{dy}_i \times \text{x2}_i\\\text{dx2}_i = \text{dy}_i \times \text{x1}_i\end{aligned}\end{align} \]

其中 dx1 和 dx2 分别是损失函数对前向输入 x1 和 x2 的梯度,dy 是损失函数对前向输出的梯度。

Gradmul1L 版本专门用于 x1 张量维度大于或等于 x2 张量的广播场景。Gradmul2L 版本专门用于 x2 张量维度大于或等于 x1 张量的广播场景。

输入:
  • dy - 来自后一层的上游梯度张量。

  • x1 - 前向传播时的第一个输入张量(被乘数)。

  • x2 - 前向传播时的第二个输入张量(乘数)。

  • params - 参数打包成结构体Parameter:
    • tile_data0 - 临时工作空间地址。

    • tile_data1 - 临时工作空间地址。

    • large_shape - x1 和 x2 中维度较大的张量的形状。

    • small_shape - x1 和 x2 中维度较小的张量的形状。

    • out_shape - 输出张量 dx1 和 dx2 的形状。

    • ndims - 张量的维度数。

    • dy_size - dy元素个数,需要初始化。

    • x1_size - x1元素个数,需要初始化。

    • x2_size - x2元素个数,需要初始化。

    • large_strides - 维度较大张量的步长信息。

    • small_strides - 维度较小张量的步长信息。

    • out_strides - 输出张量的步长信息。

    • large_multiples - 维度较大张量的广播倍数。

    • small_multiples - 维度较小张量的广播倍数。

    • indices - 用于广播计算的临时索引空间地址。

    • x1_shape - x1的维度信息地址。

    • x2_shape - x2的维度信息地址。

  • core_mask - 核掩码。

输出:
  • dx1 - 写入计算出的对 x1 的梯度。

  • dx2 - 写入计算出的对 x2 的梯度。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持fp32

  • MT7004 支持fp16, fp32

共享存储版本:

void fp_mul_grad_s(float *dy, float *dx1, float *dx2, float *x1, float *x2, Parameter *params, int core_mask)
void hp_mul_grad_s(half *dy, half *dx1, half *dx2, half *x1, half *x2, Parameter *params, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <mulgrad.h>
 4int main(int argc, char* argv[]) {
 5    float *dy = (float *)0x81000000;//输入,需要初始化
 6    float *dx1 = (float *)0x82000000;//输出,需要初始化
 7    float *dx2 = (float *)0x83000000;//输出,需要初始化
 8    float *x1_data = (float *)0x84000000;//输入,需要初始化
 9    float *x2_data = (float *)0x85000000;//输入,需要初始化
10    float *tile_data0 = (float *)0x86000000;//中间结果,不需要初始化
11    float *tile_data1 = (float *)0x87000000;//中间结果,不需要初始化
12
13    long long ndims = 4;
14    long long dy_size;
15    long long x1_size;
16    long long x2_size;
17
18    int *large_strides = (int *)0x91000000;//不需要初始化
19    int *small_strides = (int *)0x91001000; //不需要初始化
20    int *out_strides = (int *)0x91002000; //不需要初始化
21    int *large_multiples = (int *)0x91003000; //不需要初始化
22    int *small_multiples = (int *)0x91004000; //不需要初始化
23    int *indices = (int *)0x91005000;
24    float *check_dx1 = (float *)0x91006000;
25    float *check_dx2 = (float *)0x91007000;
26
27    int i = 0;
28    srand(seed++);
29
30    //初始化
31    int x1_shape[4] = {8, 1, 8, 8};
32    int x2_shape[4] = {8, 8, 8, 8};
33
34    int* large_shape = (int*) x2_shape;
35    int* small_shape = (int*) x1_shape;
36    int *out_shape = (int*) x2_shape;
37
38    dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
39    x1_size = x1_shape[0] * x1_shape[1] * x1_shape[2] * x1_shape[3];
40    x2_size = x2_shape[0] * x2_shape[1] * x2_shape[2] * x2_shape[3];
41
42    for(i = 0; i < dy_size; ++i) {
43        dy[i] = (float)(rand()%20)/2;
44    }
45
46    for(i = 0; i < x1_size; ++i) {
47        x1_data[i] = (float)(rand()%20)/2;
48    }
49
50    for(i = 0; i < x2_size; ++i) {
51        x2_data[i] = (float)(rand()%20)/2;
52    }
53
54    memset(indices, 0, ndims*sizeof(int));
55
56    Parameter params;
57    params.tile_data0 = tile_data0;
58    params.tile_data1 = tile_data1;
59    params.large_shape = large_shape;
60    params.small_shape = small_shape;
61    params.out_shape = out_shape;
62    params.ndims = ndims;
63    params.dy_size = dy_size;
64    params.x1_size = x1_size;
65    params.x2_size = x2_size;
66    params.large_strides = large_strides;
67    params.small_strides = small_strides;
68    params.out_strides = out_strides;
69    params.large_multiples = large_multiples;
70    params.small_multiples = small_multiples;
71    params.indices = indices;
72    params.x1_shape = x1_shape;
73    params.x2_shape = x2_shape;
74
75    int core_mask = 0b1111;
76    /*性能统计*/
77    fp_mul_grad_s(dy, dx1, dx2, x1_data, x2_data, &params, core_mask);
78    return 0;
79}

私有存储版本:

void fp_mul_grad_p(float *dy, float *dx1, float *dx2, float *x1, float *x2, Parameter *params)
void hp_mul_grad_p(half *dy, half *dx1, half *dx2, half *x1, half *x2, Parameter *params)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <mulgrad.h>
 4int main(int argc, char* argv[]) {
 5    float *dy = (float *)0x10010000;//输入,需要初始化
 6    float *dx1 = (float *)0x10016000;//输出,需要初始化
 7    float *dx2 = (float *)0x10020000;//输出,需要初始化
 8    float *x1_data = (float *)0x10026000;//输入,需要初始化
 9    float *x2_data = (float *)0x10030000;//输入,需要初始化
10    float *tile_data0 = (float *)0x10036000;//中间结果,不需要初始化
11    float *tile_data1 = (float *)0x10040000;//中间结果,不需要初始化
12
13    long long ndims = 4;
14    long long dy_size;
15    long long x1_size;
16    long long x2_size;
17
18    int *large_strides = (int *)0x10050000;//不需要初始化
19    int *small_strides = (int *)0x10051000; //不需要初始化
20    int *out_strides = (int *)0x10052000; //不需要初始化
21    int *large_multiples = (int *)0x10053000; //不需要初始化
22    int *small_multiples = (int *)0x10054000; //不需要初始化
23    int *indices = (int *)0x10055000;
24    float *check_dx1 = (float *)0x10060000;
25    float *check_dx2 = (float *)0x10070000;
26
27    int i = 0;
28    srand(seed++);
29
30    /*
31    1024时 large_shape = {4, 8, 4, 8}, small_shape = {4, 8, 4, 8}
32    4096时 large_shape = {8, 8, 8, 8}, small_shape = {8, 8, 8, 8}
33    */
34    //初始化
35    int x1_shape[4] = {8, 1, 8, 8};
36    int x2_shape[4] = {8, 8, 8, 8};
37
38    int* large_shape = (int*) x2_shape;
39    int* small_shape = (int*) x1_shape;
40    int *out_shape = (int*) x2_shape;
41
42    dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
43    x1_size = x1_shape[0] * x1_shape[1] * x1_shape[2] * x1_shape[3];
44    x2_size = x2_shape[0] * x2_shape[1] * x2_shape[2] * x2_shape[3];
45
46    for(i = 0; i < dy_size; ++i) {
47        dy[i] = (float)(rand()%20)/2;
48    }
49
50    for(i = 0; i < x1_size; ++i) {
51        x1_data[i] = (float)(rand()%20)/2;
52    }
53
54    for(i = 0; i < x2_size; ++i) {
55        x2_data[i] = (float)(rand()%20)/2;
56    }
57
58    memset(indices, 0, ndims*sizeof(int));
59
60    Parameter params;
61
62    params.tile_data0 = tile_data0;
63    params.tile_data1 = tile_data1;
64    params.large_shape = large_shape;
65    params.small_shape = small_shape;
66    params.out_shape = out_shape;
67    params.ndims = ndims;
68    params.dy_size = dy_size;
69    params.x1_size = x1_size;
70    params.x2_size = x2_size;
71    params.large_strides = large_strides;
72    params.small_strides = small_strides;
73    params.out_strides = out_strides;
74    params.large_multiples = large_multiples;
75    params.small_multiples = small_multiples;
76    params.indices = indices;
77    params.x1_shape = x1_shape;
78    params.x2_shape = x2_shape;
79
80    fp_mul_grad_p(dy, dx1, dx2, x1_data, x2_data, &params);
81    return 0;
82}